xtask/tasks/fmt/rustfmt.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::Xtask;
5use crate::fs_helpers::git_diffed;
6use clap::Parser;
7use std::path::PathBuf;
8use xshell::cmd;
9
10#[derive(Parser)]
11#[clap(about = "Check that all repo files are formatted using rustfmt")]
12pub struct Rustfmt {
13    /// Run `rustfmt` on all `.rs` files in the repo
14    #[clap(long)]
15    pub fix: bool,
16
17    /// A list of files to check
18    ///
19    /// If no files were provided, all files in-tree will be checked
20    pub files: Vec<PathBuf>,
21
22    /// Only run checks on files that are currently diffed
23    #[clap(long, conflicts_with = "files")]
24    pub only_diffed: bool,
25}
26
27impl Rustfmt {
28    pub fn new(fix: bool, only_diffed: bool) -> Self {
29        Self {
30            fix,
31            files: Vec::new(),
32            only_diffed,
33        }
34    }
35}
36
/// Which set of files a `Rustfmt` run operates on.
#[derive(Debug)]
enum Files {
    /// Every file in the workspace, handled through `cargo fmt`.
    All,
    /// Only `.rs` files reported as changed by `git_diffed`.
    OnlyDiffed,
    /// An explicit, non-empty list of paths supplied on the command line.
    Specific(Vec<std::path::PathBuf>),
}
43
44impl Xtask for Rustfmt {
45    fn run(self, ctx: crate::XtaskCtx) -> anyhow::Result<()> {
46        let files = if self.only_diffed {
47            Files::OnlyDiffed
48        } else if self.files.is_empty() {
49            Files::All
50        } else {
51            Files::Specific(self.files)
52        };
53
54        log::trace!("running rustfmt on {:?}", files);
55
56        let sh = xshell::Shell::new()?;
57        let rust_toolchain = sh.var("RUST_TOOLCHAIN").map(|s| format!("+{s}")).ok();
58        let fmt_check = (!self.fix).then_some("--check");
59
60        match files {
61            Files::All => {
62                cmd!(sh, "cargo {rust_toolchain...} fmt -- {fmt_check...}")
63                    .quiet()
64                    .run()?;
65            }
66            Files::OnlyDiffed => {
67                let mut files = git_diffed(ctx.in_git_hook)?;
68                files.retain(|f| f.extension().unwrap_or_default() == "rs");
69
70                if !files.is_empty() {
71                    let res = cmd!(sh, "rustfmt {rust_toolchain...} {fmt_check...} {files...}")
72                        .quiet()
73                        .run();
74
75                    if res.is_err() {
76                        anyhow::bail!("found formatting issues in diffed files");
77                    }
78                }
79            }
80            Files::Specific(files) => {
81                assert!(!files.is_empty());
82
83                cmd!(sh, "rustfmt {rust_toolchain...} {fmt_check...} {files...}")
84                    .quiet()
85                    .run()?;
86            }
87        }
88
89        log::trace!("done rustfmt");
90        Ok(())
91    }
92}